import math
import random
import numpy as np
from numba import jit
from MinkowskiEngine.utils import sparse_quantize

class LidarProcessorConfig:
    """Namespace of constants describing the LiDAR voxel / BEV grid.

    Only class attributes are defined (no methods). Extents and cell sizes
    are compared directly against point coordinates, so they share the
    coordinates' units.
    """

    # Crop extent along each axis.
    X_max, X_min = 70, -70
    Y_max, Y_min = 70, -70
    Z_max, Z_min = 2.5, -2.5

    # Cell edge length along each axis.
    dX = dY = dZ = 0.5

    # Cells per axis: extent / cell size, truncated to an int.
    X_SIZE = int((X_max - X_min) / dX)
    Y_SIZE = int((Y_max - Y_min) / dY)
    Z_SIZE = int((Z_max - Z_min) / dZ)

    lidar_dim = [X_SIZE, Y_SIZE, Z_SIZE]    # full 3-D grid
    lidar_depth_dim = [X_SIZE, Y_SIZE, 1]   # same x/y grid with a single z layer

    # Histogram bin edges: SIZE + 1 edges per axis.
    # NOTE(review): the upper end is ``*_max + 1`` rather than ``*_max``, so
    # the bin width is (extent + 1) / SIZE instead of the cell size, e.g.
    # x/y bins are ~0.5036 wide and z bins 0.6 wide, and the last edges lie
    # beyond the crop bounds. Left unchanged because consumers (e.g. trained
    # models) may depend on this exact grid -- confirm before "fixing".
    x_bins = np.linspace(X_min, X_max + 1, X_SIZE + 1)
    y_bins = np.linspace(Y_min, Y_max + 1, Y_SIZE + 1)
    z_bins = np.linspace(Z_min, Z_max + 1, Z_SIZE + 1)


@jit(nopython=True)
def filter_lidar_by_boundary(lidar_data, X_max, X_min, Y_max, Y_min, Z_max, Z_min):
    """Return the rows of ``lidar_data`` that lie inside an axis-aligned box.

    Column 0 must satisfy X_min < x < X_max and column 1 must satisfy
    Y_min < y < Y_max (both exclusive); column 2 must satisfy
    Z_min <= z <= Z_max (inclusive). Selection is by boolean mask, so the
    result is a new array containing only the surviving rows.
    """
    xs = lidar_data[..., 0]
    ys = lidar_data[..., 1]
    zs = lidar_data[..., 2]
    # Horizontal footprint: open interval on both axes.
    in_xy = (xs > X_min) & (xs < X_max) & (ys > Y_min) & (ys < Y_max)
    # Vertical slab: closed interval.
    in_z = (zs >= Z_min) & (zs <= Z_max)
    return lidar_data[in_xy & in_z]

def lidar_to_bev_v2(points, ego_length=0, ego_width=0, max_per_pixel=1, customize_z=False):
    """Rasterise a point cloud into a capped 3-D occupancy histogram.

    Points are cropped to the X/Y extent of ``LidarProcessorConfig`` (exclusive
    bounds) and to a Z range (inclusive bounds), then counted into the
    config's ``x_bins`` / ``y_bins`` / ``z_bins`` with ``np.histogramdd``.

    Args:
        points: array-like whose last axis holds at least (x, y, z); only the
            first three columns are used, so extra per-point channels (e.g.
            intensity) are ignored. An empty input yields an all-zero grid.
        ego_length: unused; accepted for backward compatibility.
        ego_width: unused; accepted for backward compatibility.
        max_per_pixel: every cell count above this value is clipped to it.
        customize_z: if True, keep only points with -2.5 <= z <= -0.3;
            otherwise use ``LidarProcessorConfig.Z_min`` / ``Z_max``.

    Returns:
        float ndarray of shape ``(X_SIZE, Y_SIZE, Z_SIZE)`` (one cell per bin
        of the config's bin-edge arrays) with counts capped at
        ``max_per_pixel``.
    """
    cfg = LidarProcessorConfig
    points = np.asarray(points)
    if points.size == 0:
        # np.asarray([]) has shape (0,), which cannot be column-indexed below;
        # normalise any empty input to zero (x, y, z) rows.
        points = np.empty((0, 3))
    # histogramdd needs exactly one sample column per bin-edge array, so drop
    # any channels beyond x, y, z before binning.
    xyz = points[..., :3]

    if customize_z:
        z_max, z_min = -0.3, -2.5
    else:
        z_max, z_min = cfg.Z_max, cfg.Z_min

    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    # Exclusive bounds on X/Y, inclusive on Z (same convention as
    # filter_lidar_by_boundary).
    inside = (
        (x < cfg.X_max) & (x > cfg.X_min)
        & (y < cfg.Y_max) & (y > cfg.Y_min)
        & (z <= z_max) & (z >= z_min)
    )
    hist = np.histogramdd(xyz[inside], bins=(cfg.x_bins, cfg.y_bins, cfg.z_bins))[0]

    # Cap per-cell counts.
    hist[hist > max_per_pixel] = max_per_pixel
    return hist

def pad_or_truncate(points, npoints):
    """Resample a set of points to exactly ``npoints`` rows.

    If there are fewer than ``npoints`` rows, every original row is kept (in
    order) and the remainder is filled with rows drawn uniformly *with*
    replacement. Otherwise ``npoints`` distinct rows are drawn *without*
    replacement; this also happens when the count already matches, so the
    output is then a random permutation of the input.

    Randomness comes from the stdlib ``random`` module (use ``random.seed``
    for reproducibility). Indices rather than rows are sampled, which draws
    from ``random`` exactly as sampling the rows themselves would.

    Args:
        points: array-like of rows, or None.
        npoints: desired number of rows; must be non-negative.

    Returns:
        ndarray with ``npoints`` rows and the same trailing shape/dtype as
        ``np.asarray(points)``, or None if ``points`` is None.

    Raises:
        ValueError: if ``npoints`` is negative.
        IndexError: if ``points`` is empty and ``npoints`` > 0 (nothing to
            pad from).
    """
    if points is None:
        return None
    if npoints < 0:
        raise ValueError(f"npoints must be non-negative, got {npoints}")

    points = np.asarray(points)
    n = len(points)
    if npoints > n:
        # Keep all rows, then append rows picked with replacement.
        extra = np.asarray(random.choices(range(n), k=npoints - n), dtype=np.intp)
        return np.concatenate((points, points[extra]), axis=0)

    # Pick ``npoints`` distinct rows; an explicit intp index keeps the
    # trailing dimensions even when npoints == 0.
    picked = np.asarray(random.sample(range(n), k=npoints), dtype=np.intp)
    return points[picked]

def pc_to_car_alignment(_pc):
    """Map points into the "car alignment" frame.

    Each point is treated as a row vector and right-multiplied by a fixed
    signed permutation matrix, giving (x, y, z) -> (-y, x, -z): a 90-degree
    turn in the xy-plane combined with negating z (determinant -1, so this
    is not a pure rotation).

    NOTE(review): the original comments both claimed the matrix was
    double-checked and marked it "TODO not sure"; confirm the intended
    target-frame convention against the code that consumes this output.

    Args:
        _pc: array-like whose last axis has length 3.

    Returns:
        ``np.matmul(_pc, M)`` with the matrix described above.
    """
    # Row i says where input axis i ends up in the output.
    to_car = np.array([
        [0, 1, 0],    # input x -> output y
        [-1, 0, 0],   # input y -> output -x
        [0, 0, -1],   # input z -> output -z
    ])
    return np.matmul(_pc, to_car)

def Sparse_Quantize(lidar, *, scale=1.0, z_max=-0.3, z_min=-2.5):
    """Quantize a LiDAR point cloud into sparse voxel coordinates.

    Steps:
      1. keep only the first three columns (x, y, z);
      2. crop with ``filter_lidar_by_boundary`` to the X/Y extent of
         ``LidarProcessorConfig`` and to ``[z_min, z_max]`` in Z;
      3. quantize with MinkowskiEngine's ``sparse_quantize`` using a cell size
         of ``(dX, dY, dZ) * scale``;
      4. multiply the returned voxel coordinates by that same cell size so
         they are back in the input's units.

    Args:
        lidar: array of shape (N, C) with C >= 3; extra columns are dropped.
        scale: multiplier on the config cell size (1.0 = config grid).
        z_max: upper Z crop bound (inclusive).
        z_min: lower Z crop bound (inclusive). The Z defaults match the
            ``customize_z`` range used by ``lidar_to_bev_v2``.

    Returns:
        float ndarray of shape (M, 3), one row per coordinate returned by
        ``sparse_quantize`` (presumably one per occupied voxel -- confirm
        against the installed MinkowskiEngine version). Shape (0, 3) if no
        point survives the crop.
    """
    cfg = LidarProcessorConfig
    cell = np.array([cfg.dX, cfg.dY, cfg.dZ])

    xyz = np.asarray(lidar)[:, :3]
    xyz = filter_lidar_by_boundary(xyz,
                                   cfg.X_max, cfg.X_min,
                                   cfg.Y_max, cfg.Y_min,
                                   z_max, z_min)
    if xyz.shape[0] == 0:
        # Nothing left after cropping: return an empty result of the usual
        # shape instead of handing an empty cloud to sparse_quantize.
        return np.empty((0, 3))

    voxels = sparse_quantize(
        xyz,
        quantization_size=(cfg.dX * scale, cfg.dY * scale, cfg.dZ * scale),
    )
    # Voxel indices -> coordinates, computed as (index * d) * scale for each
    # axis; np.array already returns a fresh float array, so no extra copy.
    return np.array(voxels, dtype=float) * cell * scale

class TransformMatrix_WorldCoords:
    """Homogeneous 4x4 transform built from a pose's rotation and location.

    ``transform`` must expose ``rotation.yaw``, ``rotation.pitch`` and
    ``rotation.roll`` in degrees (they are converted with ``np.radians``) and
    ``location.x``, ``location.y``, ``location.z`` (presumably a
    ``carla.Transform`` -- confirm with callers).

    Attributes:
        matrix: 4x4 ``np.matrix``; upper-left 3x3 is the yaw/pitch/roll
            rotation, last column holds the translation, bottom row is
            [0, 0, 0, 1].
        rcw: 3x3 rotation block (a view into ``matrix``).
        tcw: 3x1 translation column (a view into ``matrix``).
    """

    def __init__(self, transform):
        rot = transform.rotation
        loc = transform.location

        # Sine/cosine of each Euler angle.
        c_yaw, s_yaw = math.cos(np.radians(rot.yaw)), math.sin(np.radians(rot.yaw))
        c_roll, s_roll = math.cos(np.radians(rot.roll)), math.sin(np.radians(rot.roll))
        c_pitch, s_pitch = math.cos(np.radians(rot.pitch)), math.sin(np.radians(rot.pitch))

        self.matrix = np.matrix([
            [c_pitch * c_yaw,
             c_yaw * s_pitch * s_roll - s_yaw * c_roll,
             -(c_yaw * s_pitch * c_roll + s_yaw * s_roll),
             loc.x],
            [s_yaw * c_pitch,
             s_yaw * s_pitch * s_roll + c_yaw * c_roll,
             c_yaw * s_roll - s_yaw * s_pitch * c_roll,
             loc.y],
            [s_pitch,
             -(c_pitch * s_roll),
             c_pitch * c_roll,
             loc.z],
            [0.0, 0.0, 0.0, 1.0],
        ], dtype=float)

        # Slices of an np.matrix stay 2-D, so tcw is a 3x1 column.
        self.rcw = self.matrix[0:3, 0:3]
        self.tcw = self.matrix[0:3, 3]

    def inversematrix(self):
        """Return the inverse of ``self.matrix`` (as computed by ``np.linalg.inv``)."""
        return np.linalg.inv(self.matrix)
